import torch_fidelity
from PIL import Image
import os
import numpy as np
import torch
from torchvision.transforms import functional as F
from torchmetrics.image.fid import FrechetInceptionDistance

# Function to load and preprocess images
def load_images_from_directory(directory, *, crop_size=(256, 256)):
    """Load every PNG/JPEG image in *directory* into one batched tensor.

    Files are read in sorted path order, converted to RGB, center-cropped to
    ``crop_size`` and scaled from ``[0, 255]`` to ``[0.0, 1.0]`` (the range
    ``FrechetInceptionDistance(normalize=True)`` expects).

    Args:
        directory: Directory to scan (not recursive).
        crop_size: Output size passed to ``torchvision`` ``center_crop``,
            either ``(height, width)`` or a single int. Per the torchvision
            docs, images smaller than this are zero-padded before cropping.

    Returns:
        A floating-point tensor of shape ``(N, 3, height, width)``.

    Raises:
        ValueError: If *directory* contains no ``.png``/``.jpg``/``.jpeg`` files.
    """
    extensions = (".png", ".jpg", ".jpeg")
    # The leading dot matters: a bare "png" suffix would also match names
    # such as "backup_png" that are not image files.
    image_paths = sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.lower().endswith(extensions)
    )
    if not image_paths:
        # Without this check torch.stack/torch.cat on an empty list raises an
        # unhelpful RuntimeError.
        raise ValueError(f"No .png/.jpg/.jpeg images found in {directory}")

    def preprocess_image(path):
        # The context manager closes the file handle Pillow keeps open.
        with Image.open(path) as img:
            array = np.array(img.convert("RGB"))  # (H, W, 3) uint8
        # HWC -> CHW. Crop while still uint8 so only the kept pixels are
        # converted to float; cropping/zero-padding commutes with the scaling.
        tensor = torch.from_numpy(array).permute(2, 0, 1)
        tensor = F.center_crop(tensor, crop_size)
        return tensor / 255.0

    # Decode and crop one image at a time so full-resolution arrays are
    # released immediately rather than all being held in memory together.
    return torch.stack([preprocess_image(path) for path in image_paths])

# Paths for real and fake images.
# NOTE: both must be set to existing directories before running; an empty
# string makes os.listdir fail with FileNotFoundError.
real_dataset_path = ""
fake_dataset_path = ""

# Images sent through the Inception network per update() call. Passing a
# whole dataset in one call can exhaust (GPU) memory; torchmetrics accumulates
# feature statistics across update() calls, so batching is safe.
batch_size = 64
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load and preprocess images
real_images = load_images_from_directory(real_dataset_path)
fake_images = load_images_from_directory(fake_dataset_path)

print(f"Real Images Shape: {real_images.shape}")
print(f"Fake Images Shape: {fake_images.shape}")

# Compute FID. normalize=True tells torchmetrics the inputs are floats in
# [0, 1], which matches the /255.0 scaling in load_images_from_directory.
fid = FrechetInceptionDistance(normalize=True).to(device)
for images, is_real in ((real_images, True), (fake_images, False)):
    for start in range(0, len(images), batch_size):
        fid.update(images[start:start + batch_size].to(device), real=is_real)

print(f"FID Score: {float(fid.compute())}")
